{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 在 VTA 上自动调优 ALU 融合算子（fused op）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "outputs": [],
      "source": [
        "import set_env"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from mxnet.gluon.model_zoo import vision\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "\n",
        "from tvm import topi\n",
        "import tvm\n",
        "from tvm import te\n",
        "from tvm import rpc, autotvm, relay\n",
        "from tvm.contrib import download\n",
        "from tvm.autotvm.measure.measure_methods import request_remote\n",
        "from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner\n",
        "from tvm.autotvm import record\n",
        "\n",
        "import vta\n",
        "from vta.testing import simulator\n",
        "from vta.top import graph_pack\n",
        "import copy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 编译网络\n",
        "\n",
        "使用 Relay 对 Gluon 模型执行特定于 VTA 的编译："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def compile_network(env, target, model, start_pack, stop_pack):\n",
        "    # Populate the shape and data type dictionary\n",
        "    dtype_dict = {\"data\": \"float32\"}\n",
        "    shape_dict = {\"data\": (env.BATCH, 3, 224, 224)}\n",
        "\n",
        "    # Get off the shelf gluon model, and convert to relay\n",
        "    gluon_model = vision.get_model(model, pretrained=True)\n",
        "    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)\n",
        "\n",
        "    # Update shape and type dictionary\n",
        "    shape_dict.update({k: v.shape for k, v in params.items()})\n",
        "    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})\n",
        "\n",
        "    # Perform quantization in Relay\n",
        "    # Note: We set opt_level to 3 in order to fold batch norm\n",
        "    with relay.build_config(opt_level=3):\n",
        "        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):\n",
        "            mod = relay.quantize.quantize(mod, params=params)\n",
        "\n",
        "    # Perform graph packing and constant folding for VTA target\n",
        "    if target.device_name == \"vta\":\n",
        "        assert env.BLOCK_IN == env.BLOCK_OUT\n",
        "        relay_prog = graph_pack(\n",
        "            mod[\"main\"],\n",
        "            env.BATCH,\n",
        "            env.BLOCK_OUT,\n",
        "            env.WGT_WIDTH,\n",
        "            start_name=start_pack,\n",
        "            stop_name=stop_pack,\n",
        "        )\n",
        "\n",
        "    return relay_prog, params"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 设置调优选项\n",
        "\n",
        "在调优之前，需要先应用一些配置。这里以 Pynq-Z1 板为例。注意：ALU 算子的调优目前只支持 `intelfocl` 目标，在其他目标上下面的 `tune_and_evaluate` 会直接返回。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# RPC tracker address; override with TVM_TRACKER_HOST / TVM_TRACKER_PORT.\n",
        "tracker_host = os.environ.get(\"TVM_TRACKER_HOST\", \"0.0.0.0\")\n",
        "tracker_port = int(os.environ.get(\"TVM_TRACKER_PORT\", 9190))\n",
        "\n",
        "# VTA parameters are loaded from the vta/config/vta_config.json file.\n",
        "env = vta.get_env()\n",
        "\n",
        "# Choose \"vta\" to run inference on the FPGA or \"arm_cpu\" to run it on the CPU.\n",
        "# The resulting target is used for cross compilation; run `gcc -v` on the\n",
        "# device to check its toolchain.\n",
        "device = \"vta\"\n",
        "if device == \"vta\":\n",
        "    target = env.target\n",
        "else:\n",
        "    target = env.target_vta_cpu\n",
        "\n",
        "# Gluon model to compile. Graph packing (i.e. offloading to VTA) starts at\n",
        "# the ``start_pack`` op and finishes at the ``stop_pack`` op.\n",
        "network = \"resnet50_v2\"\n",
        "start_pack = \"nn.max_pool2d\"\n",
        "stop_pack = \"nn.global_avg_pool2d\"\n",
        "\n",
        "# AutoTVM tuning options (passed to tune_tasks as keyword arguments).\n",
        "log_file = \"{}.alu.{}.log\".format(device, network)\n",
        "tuning_option = dict(\n",
        "    log_filename=log_file,\n",
        "    tuner=\"random\",\n",
        "    n_trial=1000,\n",
        "    early_stopping=None,\n",
        "    measure_option=autotvm.measure_option(\n",
        "        builder=autotvm.LocalBuilder(n_parallel=1),\n",
        "        runner=autotvm.RPCRunner(\n",
        "            env.TARGET,  # device key registered with the RPC tracker\n",
        "            host=tracker_host,\n",
        "            port=tracker_port,\n",
        "            number=5,\n",
        "            timeout=60,\n",
        "            # check_correctness=True,  # TODO: re-enable when check_correctness works again.\n",
        "        ),\n",
        "    ),\n",
        ")\n",
        "\n",
        "\n",
        "def log_to_file(file_out, protocol=\"json\"):\n",
        "    \"\"\"Log the tuning records into file.\n",
        "    The rows of the log are stored in the format of autotvm.record.encode.\n",
        "    for lhs == rhs, we add an extra rhs = [] record\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    file_out : str\n",
        "        The file to log to.\n",
        "    protocol: str, optional\n",
        "        The log protocol. Can be 'json' or 'pickle'\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    callback : callable\n",
        "        Callback function to do the logging.\n",
        "    \"\"\"\n",
        "\n",
        "    def _callback(_, inputs, results):\n",
        "        with open(file_out, \"a\") as f:\n",
        "            for inp, result in zip(inputs, results):\n",
        "                f.write(record.encode(inp, result, protocol) + \"\\n\")\n",
        "\n",
        "                # we only consider task with same lhs and rhs\n",
        "                if inp.task.args[0] == inp.task.args[1]:\n",
        "                    args = list(inp.task.args)\n",
        "                    args[1] = (args[0][0], (), args[0][2])\n",
        "                    inp_copy = copy.deepcopy(inp)\n",
        "                    inp_copy.task.args = tuple(args)\n",
        "                    f.write(record.encode(inp_copy, result, protocol) + \"\\n\")\n",
        "\n",
        "    return _callback\n",
        "\n",
        "\n",
        "def tune_tasks(\n",
        "    tasks,\n",
        "    measure_option,\n",
        "    tuner=\"xgb\",\n",
        "    n_trial=10,\n",
        "    early_stopping=None,\n",
        "    log_filename=\"tuning.log\",\n",
        "    use_transfer_learning=True,\n",
        "):\n",
        "\n",
        "    # create tmp log file\n",
        "    tmp_log_file = log_filename + \".tmp\"\n",
        "    if os.path.exists(tmp_log_file):\n",
        "        os.remove(tmp_log_file)\n",
        "\n",
        "    for i, tsk in enumerate(reversed(tasks)):\n",
        "        prefix = \"[Task %2d/%2d] \" % (i + 1, len(tasks))\n",
        "\n",
        "        # create tuner\n",
        "        if tuner == \"xgb\" or tuner == \"xgb-rank\":\n",
        "            tuner_obj = XGBTuner(tsk, loss_type=\"rank\")\n",
        "        elif tuner == \"xgb_knob\":\n",
        "            tuner_obj = XGBTuner(tsk, loss_type=\"rank\", feature_type=\"knob\")\n",
        "        elif tuner == \"ga\":\n",
        "            tuner_obj = GATuner(tsk, pop_size=50)\n",
        "        elif tuner == \"random\":\n",
        "            tuner_obj = RandomTuner(tsk)\n",
        "        elif tuner == \"gridsearch\":\n",
        "            tuner_obj = GridSearchTuner(tsk)\n",
        "        else:\n",
        "            raise ValueError(\"Invalid tuner: \" + tuner)\n",
        "\n",
        "        if use_transfer_learning:\n",
        "            if os.path.isfile(tmp_log_file):\n",
        "                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))\n",
        "\n",
        "        # do tuning\n",
        "        tsk_trial = min(n_trial, len(tsk.config_space))\n",
        "        tuner_obj.tune(\n",
        "            n_trial=tsk_trial,\n",
        "            early_stopping=early_stopping,\n",
        "            measure_option=measure_option,\n",
        "            callbacks=[\n",
        "                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),\n",
        "                log_to_file(tmp_log_file),\n",
        "            ],\n",
        "        )\n",
        "\n",
        "    # pick best records to a cache file\n",
        "    autotvm.record.pick_best(tmp_log_file, log_filename)\n",
        "    os.remove(tmp_log_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "注册特定于 VTA 的调优任务："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def register_vta_tuning_tasks():\n",
        "    \"\"\"Register the AutoTVM templates \"add.vta\" and \"multiply.vta\".\n",
        "\n",
        "    Both templates compute the packed VTA ALU op, clip the result to\n",
        "    [0, 127] and cast it to int8. Under a target whose ``device_name`` is\n",
        "    \"vta\" they use the matching VTA schedule, otherwise a plain TE schedule.\n",
        "    \"\"\"\n",
        "    from tvm.autotvm.task import TaskExtractEnv\n",
        "\n",
        "    @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)\n",
        "    def my_clip(x, a_min, a_max):\n",
        "        \"\"\"Clip ``x`` to [a_min, a_max] using two separate compute stages.\n",
        "\n",
        "        Unlike topi's clip, the upper bound is applied in stage \"clipA\" and\n",
        "        the lower bound in a second stage \"clipB\".\n",
        "        \"\"\"\n",
        "        lower = tvm.tir.const(a_min, x.dtype)\n",
        "        upper = tvm.tir.const(a_max, x.dtype)\n",
        "        capped = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), upper), name=\"clipA\")\n",
        "        return te.compute(capped.shape, lambda *i: tvm.te.max(capped(*i), lower), name=\"clipB\")\n",
        "\n",
        "    # Instantiating TaskExtractEnv initializes the AutoTVM env so the VTA\n",
        "    # operators are registered.\n",
        "    TaskExtractEnv()\n",
        "\n",
        "    def _register_alu_template(task_name, compute_name, schedule_name):\n",
        "        \"\"\"Register ``task_name`` as vta.top.op.<compute_name> -> clip -> int8.\"\"\"\n",
        "\n",
        "        def _template(*args, **kwargs):\n",
        "            assert not kwargs, \"Do not support kwargs in template function call\"\n",
        "            lhs, rhs = args[:2]\n",
        "\n",
        "            with tvm.target.vta():\n",
        "                out = getattr(vta.top.op, compute_name)(*args, **kwargs)\n",
        "                out = my_clip(out, 0, 127)\n",
        "                out = topi.cast(out, \"int8\")\n",
        "\n",
        "            if tvm.target.Target.current().device_name != \"vta\":\n",
        "                return te.create_schedule([out.op]), [lhs, rhs, out]\n",
        "            sched = getattr(vta.top.op, schedule_name)([out])\n",
        "            return sched, [lhs, rhs, out]\n",
        "\n",
        "        autotvm.template(task_name)(_template)\n",
        "\n",
        "    _register_alu_template(\"add.vta\", \"add_packed\", \"schedule_add_packed\")\n",
        "    _register_alu_template(\"multiply.vta\", \"multiply_packed\", \"schedule_multiply_packed\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "最后，提取 ALU 任务并启动调优作业。注意：这里只执行调优，不会评估端到端性能。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "ALU only op only available for intelfocl target\n"
          ]
        }
      ],
      "source": [
        "def tune_and_evaluate(tuning_opt):\n",
        "\n",
        "    if env.TARGET != \"intelfocl\":\n",
        "        print(\"ALU only op only available for intelfocl target\")\n",
        "        return\n",
        "\n",
        "    # Register VTA tuning tasks\n",
        "    register_vta_tuning_tasks()\n",
        "\n",
        "    # Perform task extraction on Relay program\n",
        "    print(\"Extract tasks...\")\n",
        "    relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)\n",
        "    mod = tvm.IRModule.from_expr(relay_prog)\n",
        "    tasks = autotvm.task.extract_from_program(\n",
        "        mod,\n",
        "        params=params,\n",
        "        ops=(\n",
        "            relay.op.get(\"add\"),\n",
        "            relay.op.get(\"multiply\"),\n",
        "        ),\n",
        "        target=tvm.target.Target(target, host=env.target_host),\n",
        "    )\n",
        "\n",
        "    # filter out non-packed alu task\n",
        "    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))\n",
        "    # filter out float alu task\n",
        "    tasks = list(filter(lambda t: t.args[0][2] != \"float32\", tasks))\n",
        "\n",
        "    # We should have extracted 10 convolution tasks\n",
        "    tasks_set = {}\n",
        "    print(\"Extracted {} alu tasks:\".format(len(tasks)))\n",
        "    for tsk in tasks:\n",
        "        print(\"tsk = \", tsk)\n",
        "\n",
        "        if len(tsk.args[1][1]) == 0:\n",
        "            args = list(tsk.args)\n",
        "            args[1] = args[0]\n",
        "            tsk.args = tuple(args)\n",
        "\n",
        "        if (tsk.name, tsk.args) in tasks_set:\n",
        "            print(\"task {} already exists\".format(tsk))\n",
        "        tasks_set[(tsk.name, tsk.args)] = tsk\n",
        "\n",
        "    tasks = list(tasks_set.values())\n",
        "    print(\"After merged, final #tasks={}, tasks = {}\".format(len(tasks), tasks))\n",
        "\n",
        "    # run tuning tasks\n",
        "    print(\"Tuning...\")\n",
        "    tune_tasks(tasks, **tuning_opt)\n",
        "\n",
        "\n",
        "# Extract the ALU tasks and tune them (returns early unless env.TARGET is intelfocl).\n",
        "tune_and_evaluate(tuning_option)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
